# SPDX-FileCopyrightText: © 2025 Tenstorrent AI ULC

# SPDX-License-Identifier: Apache-2.0

import pytest
import torch
import ttnn
import numpy as np
from loguru import logger
from models.experimental.vadv2.reference import vad
from models.experimental.vadv2.tt import tt_vad
from models.experimental.vadv2.reference.base_box3d import LiDARInstance3DBoxes
from models.experimental.vadv2.tt.model_preprocessing import (
    create_vadv2_model_parameters_vad,
)
from tests.ttnn.utils_for_testing import assert_with_pcc
from models.experimental.vadv2.common import load_torch_model


# Constructor kwargs shared by the reference `vad.VAD` and the TT `tt_vad.TtVAD`.
# `use_grid_mask` is intentionally absent: it differs between the two models
# and is passed explicitly at each construction site.
_VAD_CONFIG = dict(
    pts_voxel_layer=None,
    pts_voxel_encoder=None,
    pts_middle_encoder=None,
    pts_fusion_layer=None,
    img_backbone=True,
    pts_backbone=None,
    img_neck=True,
    pts_neck=None,
    pts_bbox_head=True,
    img_roi_head=None,
    img_rpn_head=None,
    train_cfg=None,
    test_cfg=None,
    pretrained=None,
    video_test_mode=True,
    fut_ts=6,
    fut_mode=6,
)

# Shape of the random multi-camera image input. Presumably
# (batch, num_cams, channels, H, W): 6 matches the number of camera files below
# and 384x640 matches `img_shape` / `pad_shape` in the metas — TODO confirm.
_IMG_SHAPE = (1, 6, 3, 384, 640)

# Output names compared between the reference and TT dumps.
_OUTPUT_KEYS = (
    "bev_embed",
    "all_cls_scores",
    "all_bbox_preds",
    "all_traj_preds",
    "all_traj_cls_scores",
    "map_all_cls_scores",
    "map_all_bbox_preds",
    "map_all_pts_preds",
    "ego_fut_preds",
)

# Directories holding one `<key>.pt` file per entry of `_OUTPUT_KEYS`.
# NOTE(review): nothing in this test writes these files; presumably the
# reference / TT forward passes dump them as a side effect — TODO confirm.
_REFERENCE_DUMP_DIR = "models/experimental/vadv2/reference/dumps"
_TT_DUMP_DIR = "models/experimental/vadv2/tt/dumps"

# NOTE(review): a PCC threshold of 0.0 accepts almost any positively correlated
# output, so this currently acts mainly as a shape / smoke check.
# TODO: raise once the expected TT accuracy is established.
_PCC_THRESHOLD = 0.0

# Image paths for the six cameras of the sample. Presumably index-aligned with
# `_LIDAR2IMG` — TODO confirm.
_CAM_FILENAMES = [
    "./data/nuscenes/samples/CAM_FRONT/n008-2018-08-01-15-16-36-0400__CAM_FRONT__1533151603512404.jpg",
    "./data/nuscenes/samples/CAM_FRONT_RIGHT/n008-2018-08-01-15-16-36-0400__CAM_FRONT_RIGHT__1533151603520482.jpg",
    "./data/nuscenes/samples/CAM_FRONT_LEFT/n008-2018-08-01-15-16-36-0400__CAM_FRONT_LEFT__1533151603504799.jpg",
    "./data/nuscenes/samples/CAM_BACK/n008-2018-08-01-15-16-36-0400__CAM_BACK__1533151603537558.jpg",
    "./data/nuscenes/samples/CAM_BACK_LEFT/n008-2018-08-01-15-16-36-0400__CAM_BACK_LEFT__1533151603547405.jpg",
    "./data/nuscenes/samples/CAM_BACK_RIGHT/n008-2018-08-01-15-16-36-0400__CAM_BACK_RIGHT__1533151603528113.jpg",
]

# The tables below hold raw literals only. `_build_input_dict` wraps them in
# fresh numpy arrays / torch tensors on every call, because a model forward
# may modify the metas in place. Keeping the tables as plain lists means such
# edits can never leak between tests.
# fmt: off

# One 4x4 `lidar2img` matrix per camera.
_LIDAR2IMG = [
    [
        [4.97195909e02, 3.36259809e02, 1.31050214e01, -1.41740456e02],
        [-7.28050437e00, 2.14719425e02, -4.90215017e02, -2.57883151e02],
        [-1.17025046e-02, 9.98471159e-01, 5.40221896e-02, -4.25203639e-01],
        [0.00000000e00, 0.00000000e00, 0.00000000e00, 1.00000000e00],
    ],
    [
        [5.45978616e02, -2.47705944e02, -1.61356657e01, -1.84657143e02],
        [1.51784935e02, 1.28122911e02, -4.95917894e02, -2.77022512e02],
        [8.43406855e-01, 5.36312055e-01, 3.21598489e-02, -6.10371854e-01],
        [0.00000000e00, 0.00000000e00, 0.00000000e00, 1.00000000e00],
    ],
    [
        [1.29479337e01, 6.01261709e02, 3.10492731e01, -1.20975154e02],
        [-1.55728079e02, 1.28176621e02, -4.94981202e02, -2.71769902e02],
        [-8.23415292e-01, 5.65940098e-01, 4.12196894e-02, -5.29677094e-01],
        [0.00000000e00, 0.00000000e00, 0.00000000e00, 1.00000000e00],
    ],
    [
        [-3.21592898e02, -3.40289545e02, -1.05750653e01, -3.48318395e02],
        [-4.32931264e00, -1.78114385e02, -3.25958977e02, -2.83473696e02],
        [-8.33350064e-03, -9.99200442e-01, -3.91028008e-02, -1.01645350e00],
        [0.00000000e00, 0.00000000e00, 0.00000000e00, 1.00000000e00],
    ],
    [
        [-4.74626444e02, 3.69304577e02, 2.13056637e01, -2.50136476e02],
        [-1.85050206e02, -4.10162348e01, -5.00990867e02, -2.24731382e02],
        [-9.47586752e-01, -3.19482867e-01, 3.16948959e-03, -4.32527296e-01],
        [0.00000000e00, 0.00000000e00, 0.00000000e00, 1.00000000e00],
    ],
    [
        [1.14075693e02, -5.87710608e02, -2.38253717e01, -1.09040128e02],
        [1.77894417e02, -4.91302807e01, -5.00157067e02, -2.35298447e02],
        [9.24052925e-01, -3.82246554e-01, -3.70989150e-03, -4.64645142e-01],
        [0.00000000e00, 0.00000000e00, 0.00000000e00, 1.00000000e00],
    ],
]

# 18-element `can_bus` vector of the sample.
_CAN_BUS = [
    6.50486842e02, 1.81754303e03, 0.00000000e00, 1.84843146e-01, 1.84843146e-01, 1.84843146e-01,
    1.84843146e-01, 8.47522666e-01, 1.34135536e00, 9.58588434e00, -9.57939215e-03, 6.51179999e-03,
    3.75314295e-01, 3.77446848e00, 0.00000000e00, 0.00000000e00, 3.51370076e00, 2.01320224e02,
]

# 11 ground-truth boxes, 9 values each (built with `box_dim=9`).
_GT_BOXES = [
    [-1.49, -14.395, -1.697, 1.772, 4.294, 1.597, 2.7079, 0.033, -0.118],
    [-5.703, -16.928, -1.723, 1.696, 4.983, 1.604, 2.7118, 0.04025, -0.02866],
    [2.212, -16.031, -1.756, 1.764, 4.135, 1.565, 2.7047, 0.03279, -0.08034],
    [-8.515, -16.771, -1.631, 1.709, 4.624, 1.538, 2.7163, 0.03411, 0.05917],
    [-3.307, -19.646, -1.678, 1.785, 5.143, 1.597, 2.7124, 0.03402, -0.12001],
    [-9.044, -12.747, -1.671, 1.763, 4.171, 1.586, 2.7187, 0.03797, -0.11996],
    [-9.413, -9.297, -1.692, 1.777, 4.312, 1.581, 2.7201, 0.04248, -0.12339],
    [-7.015, -14.134, -1.706, 1.748, 4.462, 1.602, 2.7136, 0.03665, -0.09436],
    [-4.188, -16.865, -1.726, 1.729, 4.759, 1.581, 2.7098, 0.03482, -0.10607],
    [-6.324, -18.912, -1.687, 1.773, 5.128, 1.602, 2.7149, 0.03843, -0.11187],
    [-5.529, -15.985, -1.681, 1.711, 4.251, 1.588, 2.7102, 0.03579, -0.11788],
]

# Class label for each entry of `_GT_BOXES` (11 values).
_GT_LABELS = [8, 8, 8, 8, 8, 8, 0, 8, 8, 0, 8]

# 11 rows x 34 attribute values. Presumably one row per entry of `_GT_BOXES`
# (the row counts match) — TODO confirm column semantics against the dataset code.
_GT_ATTR_LABELS = [
    [
        -1.6727e-01, -7.2574e-01, -1.6727e-01, -7.2574e-01, -1.5026e-01, -6.4606e-01,
        -1.5114e-01, -6.4654e-01, -1.1047e-01, -6.7103e-01, -1.1046e-01, -6.7097e-01,
        1.0000e00, 1.0000e00, 1.0000e00, 1.0000e00, 1.0000e00, 1.0000e00,
        9.0000e00, 7.3898e00, 1.7461e01, -1.8047e00, -3.3320e-01, -1.4408e00,
        6.8800e-01, 9.4400e-01, 1.9040e00, 0.0000e00, 1.7488e-02, 1.7488e-02,
        1.7488e-02, 1.7488e-02, 0.0000e00, 0.0000e00,
    ],
    [
        -3.0019e-03, -6.9224e-01, -3.0019e-03, -6.9224e-01, 6.8656e-02, -7.2279e-01,
        8.7407e-02, -7.2163e-01, 7.6763e-02, -7.2473e-01, 7.8121e-02, -7.2512e-01,
        1.0000e00, 1.0000e00, 1.0000e00, 1.0000e00, 1.0000e00, 1.0000e00,
        9.0000e00, -1.0372e01, 1.4923e01, -1.5259e00, -4.2574e-03, -1.3675e00,
        5.7800e-01, 6.1300e-01, 1.7520e00, 0.0000e00, 8.7434e-03, 8.7433e-03,
        2.6229e-02, 2.6228e-02, -8.7426e-03, -8.7427e-03,
    ],
    [
        -1.4037e-01, -9.3700e-01, -1.0510e-01, -7.8508e-01, -1.6029e-01, -8.3854e-01,
        -1.5965e-01, -7.8730e-01, -2.5623e-01, -6.8597e-01, -1.6448e-01, -7.5521e-01,
        1.0000e00, 1.0000e00, 1.0000e00, 1.0000e00, 1.0000e00, 1.0000e00,
        9.0000e00, 9.8771e00, 2.9820e01, -1.7349e00, -2.7933e-01, -1.8618e00,
        7.5100e-01, 1.0300e00, 1.9750e00, 0.0000e00, 1.7488e-02, 1.7488e-02,
        -6.9953e-02, -1.0492e-01, -3.4971e-02, 2.0985e-01,
    ],
    [
        -9.2983e-02, 6.7894e-01, -9.2502e-02, 6.7807e-01, -9.5387e-02, 6.6988e-01,
        -7.1772e-02, 6.7257e-01, -1.9898e-02, 5.7937e-01, 9.1245e-02, 5.8810e-01,
        1.0000e00, 1.0000e00, 1.0000e00, 1.0000e00, 1.0000e00, 1.0000e00,
        6.0000e00, 5.7166e00, -1.2945e-01, 1.7173e00, -1.8581e-01, 1.3567e00,
        6.3100e-01, 6.1000e-01, 1.9290e00, 0.0000e00, -8.7421e-03, -8.7422e-03,
        -3.4970e-02, -3.4972e-02, -1.3116e-01, -1.1367e-01,
    ],
    [
        -1.2359e-02, -6.5865e-01, 3.7632e-03, -6.5893e-01, -4.3455e-02, -5.6672e-01,
        -1.2945e-01, -5.5347e-01, -1.7873e-01, -6.5768e-01, -2.3105e-01, -6.2073e-01,
        1.0000e00, 1.0000e00, 1.0000e00, 1.0000e00, 1.0000e00, 1.0000e00,
        9.0000e00, -9.4140e00, 1.9118e01, -1.5608e00, -2.2955e-02, -1.3004e00,
        6.9700e-01, 4.9800e-01, 1.7610e00, 0.0000e00, 2.6231e-02, 2.6230e-02,
        -1.5739e-01, -1.5739e-01, -8.7430e-02, -8.7417e-02,
    ],
    [
        8.0779e-02, 7.4353e-01, 8.1181e-02, 7.4495e-01, 6.3665e-02, 7.3262e-01,
        1.0909e-01, 7.2669e-01, 1.9880e-01, 6.9847e-01, 1.9744e-01, 6.9886e-01,
        1.0000e00, 1.0000e00, 1.0000e00, 1.0000e00, 1.0000e00, 1.0000e00,
        5.0000e00, 1.0641e01, 2.7106e01, 1.4385e00, 1.6085e-01, 1.4806e00,
        6.6500e-01, 7.3600e-01, 1.8900e00, 0.0000e00, 0.0000e00, 0.0000e00,
        -6.1209e-02, -6.1207e-02, -1.7487e-02, -1.7486e-02,
    ],
    [
        -4.8422e-02, 4.3918e00, -7.5860e-03, 4.3614e00, -3.2758e-02, 4.2372e00,
        -7.1883e-02, 3.9873e00, -3.7413e-02, 3.6454e00, -1.6599e-02, 3.0838e00,
        1.0000e00, 1.0000e00, 1.0000e00, 1.0000e00, 1.0000e00, 1.0000e00,
        6.0000e00, 1.7422e-01, -1.9751e01, 1.5892e00, -9.8724e-02, 8.7583e00,
        1.8710e00, 4.4880e00, 1.5150e00, 8.0000e00, -1.7488e-02, -3.4975e-03,
        1.3990e-02, -3.4975e-03, -3.4975e-03, -3.4975e-03,
    ],
    [
        9.1897e-02, -8.1954e-01, 2.6250e-02, -6.9609e-01, -7.9227e-02, -7.5865e-01,
        0.0000e00, 0.0000e00, 0.0000e00, 0.0000e00, 0.0000e00, 0.0000e00,
        1.0000e00, 1.0000e00, 1.0000e00, 0.0000e00, 0.0000e00, 0.0000e00,
        9.0000e00, -9.8804e00, -2.6722e01, -1.6065e00, 1.8469e-01, -1.6281e00,
        9.0800e-01, 1.1090e00, 2.2110e00, 0.0000e00, 0.0000e00, 0.0000e00,
        0.0000e00, 0.0000e00, 0.0000e00, 0.0000e00,
    ],
    [
        -2.7701e-03, 6.9534e-01, -2.3412e-03, 6.7212e-01, -2.2083e-03, 6.6809e-01,
        2.2220e-02, 6.3216e-01, 2.1917e-02, 6.2721e-01, 1.3461e-02, 6.2600e-01,
        1.0000e00, 1.0000e00, 1.0000e00, 1.0000e00, 1.0000e00, 1.0000e00,
        5.0000e00, 6.5109e00, -1.1699e-01, 1.5951e00, -5.5354e-03, 1.3895e00,
        7.1200e-01, 6.0100e-01, 1.8910e00, 0.0000e00, 0.0000e00, -8.7437e-03,
        -8.7438e-03, -1.7488e-02, -1.7488e-02, -3.4976e-02,
    ],
    [
        2.3941e-02, -5.6354e00, 1.7062e-01, -5.6324e00, 1.3307e-01, -5.7024e00,
        0.0000e00, 0.0000e00, 0.0000e00, 0.0000e00, 0.0000e00, 0.0000e00,
        1.0000e00, 1.0000e00, 1.0000e00, 0.0000e00, 0.0000e00, 0.0000e00,
        9.0000e00, -3.2979e00, -2.2998e01, -1.5837e00, 5.0695e-02, -1.1235e01,
        2.0370e00, 4.9580e00, 1.6390e00, 8.0000e00, 2.6232e-02, 2.6231e-02,
        -1.7487e-02, 0.0000e00, 0.0000e00, 0.0000e00,
    ],
    [
        1.0777e-02, -9.2886e-04, -1.8136e-01, 1.4521e-02, -1.3091e-02, -3.2210e-02,
        -3.7818e-02, -5.6379e-02, 7.9810e-03, 9.8996e-02, 5.0912e-03, 7.0949e-02,
        1.0000e00, 1.0000e00, 1.0000e00, 1.0000e00, 1.0000e00, 1.0000e00,
        9.0000e00, -1.2731e01, 2.8231e01, 3.0605e00, 2.1535e-02, -1.8561e-03,
        7.3800e-01, 7.8300e-01, 1.5200e00, 0.0000e00, 0.0000e00, 0.0000e00,
        0.0000e00, 0.0000e00, 0.0000e00, 0.0000e00,
    ],
]

# fmt: on


def _build_input_dict():
    """Return the hard-coded single-sample inputs shared by both models.

    The values come from one nuScenes sample (see the file paths and tokens
    below). Every call builds new numpy arrays and torch tensors, so callers
    get objects that no other test has touched.

    The list nesting of each entry is kept exactly as the models' forward
    calls expect it. ``map_gt_labels_3d`` / ``map_gt_bboxes_3d`` are carried
    along but are not passed to either model by this test.
    """
    img_meta = {
        "filename": list(_CAM_FILENAMES),
        "ori_shape": [(360, 640, 3)] * 6,
        "img_shape": [(384, 640, 3)] * 6,
        "lidar2img": [np.array(matrix) for matrix in _LIDAR2IMG],
        "pad_shape": [(384, 640, 3)] * 6,
        "scale_factor": 1.0,
        "flip": False,
        "pcd_horizontal_flip": False,
        "pcd_vertical_flip": False,
        "box_type_3d": LiDARInstance3DBoxes,
        "img_norm_cfg": {
            "mean": np.array([123.675, 116.28, 103.53], dtype=np.float32),
            "std": np.array([58.395, 57.12, 57.375], dtype=np.float32),
            "to_rgb": True,
        },
        "sample_idx": "3e8750f331d7499e9b5123e9eb70f2e2",
        "prev_idx": "",
        "next_idx": "3950bd41f74548429c0f7700ff3d8269",
        "pcd_scale_factor": 1.0,
        "pts_filename": "data/pcd.bin",
        "scene_token": "fcbccedd61424f1b85dcbf8f897f9754",
        "can_bus": np.array(_CAN_BUS),
    }
    return {
        "img_metas": [[[img_meta]]],
        "gt_bboxes_3d": [[[LiDARInstance3DBoxes(torch.tensor(_GT_BOXES), box_dim=9)]]],
        "gt_labels_3d": [[[torch.tensor(_GT_LABELS)]]],
        "fut_valid_flag": [torch.tensor([True])],
        "ego_his_trajs": [[torch.tensor([[[[0.0757, 4.2529], [0.0757, 4.2529]]]])]],
        "ego_fut_trajs": [[torch.zeros((1, 1, 6, 2))]],
        "ego_fut_masks": [[torch.ones((1, 1, 6))]],
        "ego_fut_cmd": [[torch.tensor([[[[1.0, 0.0, 0.0]]]])]],
        "ego_lcf_feat": [[torch.zeros((1, 1, 1, 9), dtype=torch.float32)]],
        "gt_attr_labels": [[[torch.tensor(_GT_ATTR_LABELS)]]],
        "map_gt_labels_3d": [[torch.zeros((7,), dtype=torch.long)]],
        "map_gt_bboxes_3d": [[None]],
    }


def _forward_kwargs(input_dict, img):
    """Build the inference-mode forward kwargs used for both models.

    Both models receive the same ``input_dict`` objects and differ only in
    ``img`` (a list holding a torch tensor or a ttnn tensor).
    """
    return dict(
        return_loss=False,
        img=img,
        img_metas=input_dict["img_metas"],
        gt_bboxes_3d=input_dict["gt_bboxes_3d"],
        gt_labels_3d=input_dict["gt_labels_3d"],
        fut_valid_flag=input_dict["fut_valid_flag"],
        ego_his_trajs=input_dict["ego_his_trajs"],
        ego_fut_trajs=input_dict["ego_fut_trajs"],
        ego_fut_cmd=input_dict["ego_fut_cmd"],
        ego_lcf_feat=input_dict["ego_lcf_feat"],
        gt_attr_labels=input_dict["gt_attr_labels"],
    )


@pytest.mark.parametrize("device_params", [{"l1_small_size": 20 * 1024}], indirect=True)
def test_vadv2(
    device,
    reset_seeds,
    model_location_generator,
):
    """Compare the TT VADv2 model against the PyTorch reference on one sample.

    Both models are built from ``_VAD_CONFIG`` and run on the same random
    image tensor and hard-coded metadata. The outputs listed in
    ``_OUTPUT_KEYS`` are then compared with PCC, using the ``.pt`` dumps
    found in ``_REFERENCE_DUMP_DIR`` and ``_TT_DUMP_DIR``.
    """
    torch_model = vad.VAD(use_grid_mask=True, **_VAD_CONFIG)
    torch_model = load_torch_model(torch_model=torch_model, model_location_generator=model_location_generator)

    # The same objects are shared by the reference forward, parameter
    # preprocessing and the TT forward, exactly as one dict would be.
    input_dict = _build_input_dict()

    # Keep this draw after model construction/loading: `reset_seeds` presumably
    # seeds torch's RNG, and reordering would change the sampled input.
    torch_input = torch.randn(*_IMG_SHAPE)
    torch_img = [torch_input]

    # The return value is not used. Outputs are compared through the dump
    # files below, which this forward pass presumably writes — TODO confirm.
    with torch.no_grad():
        torch_model(**_forward_kwargs(input_dict, torch_img))

    parameter = create_vadv2_model_parameters_vad(
        torch_model,
        [
            False,
            torch_img,
            input_dict["img_metas"],
            input_dict["gt_attr_labels"],
            input_dict["gt_bboxes_3d"],
            input_dict["gt_labels_3d"],
            input_dict["ego_his_trajs"],
            input_dict["ego_fut_trajs"],
            input_dict["ego_fut_cmd"],
            input_dict["ego_lcf_feat"],
        ],
        device,
    )

    tt_img = [ttnn.from_torch(torch_input, dtype=ttnn.bfloat16, device=device, layout=ttnn.ROW_MAJOR_LAYOUT)]
    # Grid masking is meant for training only, so it stays off for this
    # inference comparison on the TT side.
    tt_model = tt_vad.TtVAD(device, parameter, use_grid_mask=False, **_VAD_CONFIG)
    tt_model(**_forward_kwargs(input_dict, tt_img))

    for key in _OUTPUT_KEYS:
        expected = torch.load(f"{_REFERENCE_DUMP_DIR}/{key}.pt")
        actual = torch.load(f"{_TT_DUMP_DIR}/{key}.pt")
        _, msg = assert_with_pcc(expected, actual, _PCC_THRESHOLD)
        logger.info(f"{key}: {msg}")
